
import copy
import importlib
import os
import sys
import time

import matplotlib.pyplot as plt
import numpy as np
import torch

#--------------------------------#
from FwdBwdNeuralEqV3 import *
from Tx import *
from eq import *
from neuralEQ import *
from utils import *
import device



# Number of amplitude levels for each supported modulation format.
_MOD_ORDER = {'nrz': 2, 'pam4': 4, 'pam8': 8}


def _modulation_order(mod):
	"""Return the number of amplitude levels for modulation format ``mod``.

	Exits the script with the message 'invalid modulation' when ``mod`` is not
	one of 'nrz', 'pam4' or 'pam8'.
	"""
	if mod not in _MOD_ORDER:
		sys.exit('invalid modulation')
	return _MOD_ORDER[mod]


def _eq_delay(inSize, chSBR):
	"""Return the delay handed to fine-tuning and evaluation.

	Computed as a quarter of the equalizer input size (truncated to int) minus
	the index of the first maximum of ``chSBR`` (presumably the channel
	single-bit response, so this index is its main-cursor position).
	"""
	return int(inSize / 4) - list(chSBR).index(max(chSBR))


if __name__ == "__main__":
	#*************************HEADER***********************#
	np.random.seed(1)
	args = parsing_def()
	# Config modules are looked up in ./config (relative to the working
	# directory) as config_<name>; each one exposes a `config` dict.
	sys.path.insert(0, './config')
	cfg = importlib.import_module('config_{}'.format(args.config)).config
	evalCfg = cfg['eval']

	# Fail fast on an unsupported modulation before any model is loaded.
	_modulation_order(evalCfg['mod'])
	evalDelay = _eq_delay(evalCfg['inSize'], evalCfg['chSBR'])
	#******************************************************#

	tx = Tx(mod=evalCfg['mod'])
	for modelFile in evalCfg['modelFileList']:
		# map_location lets a checkpoint saved on another device (e.g. a GPU)
		# be loaded onto device.device on this machine.
		# NOTE(review): this unpickles a whole nn.Module. On torch >= 2.6 the
		# default weights_only=True may reject it (pass weights_only=False if
		# so), and only trusted model files should ever be loaded this way.
		nEqLoad = torch.load(modelFile, map_location=device.device)
		nEqLoad = nEqLoad.to(device.device)
		# Snapshot the loaded weights: fine-tuning updates the model in place,
		# so without restoring, each SNR would start from the previous SNR's
		# fine-tuned model instead of from the checkpoint.
		pretrainedState = copy.deepcopy(nEqLoad.state_dict()) if evalCfg['finetune'] else None
		berTestHis = []
		print(f"Model loaded ... ({modelFile})")
		for snrTest in evalCfg['snrTestList']:
			# Validation data is generated even when fine-tuning is off.
			# Presumably tx.run / Channel draw from the np.random stream seeded
			# above, so keeping this unconditional keeps the test data below
			# independent of the finetune flag.
			chInValid = tx.run(int(evalCfg['dataSizeValid']))
			chValid = Channel(sbr=evalCfg['chSBR'], snr=snrTest)
			chOutValid = chValid.run(chIn=chInValid, flagN=evalCfg['noiseFlag'])
			print(f"Eval start for snr: {snrTest}", flush=True)
			chInTest = tx.run(int(evalCfg['dataSizeTest']))
			chTest = Channel(sbr=evalCfg['chSBR'], snr=snrTest)
			chOutTest = chTest.run(chIn=chInTest, flagN=evalCfg['noiseFlag'])
			if evalCfg['finetune']:
				# Every SNR is fine-tuned from the original checkpoint weights.
				nEqLoad.load_state_dict(pretrainedState)
				opt = torch.optim.Adam(
					nEqLoad.parameters(),
					lr=evalCfg['lr'],
					weight_decay=evalCfg['weightDecay'],
				)
				trainLossHis, validLossHis, validBerHis = trainEval(
					nEqLoad,
					tx,
					chInValid,
					chOutValid,
					evalCfg['numEpoch'],
					evalCfg['evalFreq'],
					evalCfg['mod'],
					evalCfg['chSBR'],
					evalCfg['inSize'],
					evalCfg['outSize'],
					evalCfg['batchSize'],
					evalDelay,
					evalCfg['lossFn'],
					opt,
					int(evalCfg['dataSizeTrain']),
					snrTest,
					evalCfg['noiseFlag'],
				)
			simNeq = simNeuralEQ(
				txDataTrain=None,
				rxDataTrain=None,
				txDataTest=chInTest,
				rxDataTest=chOutTest,
				neuralEQ=nEqLoad,
				mod=evalCfg['mod'],
			)

			testLoss, berTest = simNeq.evalNeuralEQ(
				evalCfg['lossFn'],
				batchSize=evalCfg['batchSize'],
				inSize=evalCfg['inSize'],
				outSize=evalCfg['outSize'],
				delay=evalDelay,
			)
			print(f"berTest: {berTest} @ snr:{snrTest}",flush=True)
			berTestHis.append(berTest)
		print(f"berTestHis@{modelFile}")
		print(f"{berTestHis}")